import copy
import os

import numpy as np
import torch
from scipy.sparse import csr_matrix

from controlsnr import find_a_given_snr
from scipy.sparse import csr_matrix
import numpy as np
from typing import Optional, Union, Tuple, Dict

def simple_collate_fn(batch):
    """
    Collate a batch of equally sized graphs into stacked tensors.

    Every sample is a mapping holding a sparse adjacency matrix under 'adj'
    (any object exposing ``toarray()``) and per-node labels under 'labels'.
    All graphs must have the same number of nodes, otherwise stacking fails.

    Returns a dict with:
      - 'adj': float32 tensor, [B, N, N]
      - 'labels': int64 tensor, [B, N]
    """
    dense_adjs = []
    node_labels = []
    for sample in batch:
        dense_adjs.append(torch.tensor(sample['adj'].toarray(), dtype=torch.float32))
        node_labels.append(torch.tensor(sample['labels'], dtype=torch.long))

    return {
        'adj': torch.stack(dense_adjs),      # [B, N, N]
        'labels': torch.stack(node_labels),  # [B, N]
    }
# Training grid: one cell per (SNR, gamma, C) combination, per_cell_tr graphs per cell.
snr_train   = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]   # kept unchanged
gamma_train = [0.5, 1.0, 2.0, 3.5, 5.0]                                      # Dirichlet concentration: from more imbalanced to nearly uniform class sizes
C_train     = [5.0, 10.0, 15.0, 20.0]                                        # kept unchanged
per_cell_tr =  30

# Validation grid = midpoints of adjacent training values (grid sizes: 9, 4, 3)
def midpoints(vals):
    """Return the average of every pair of adjacent entries of ``vals``, in order."""
    return [(left + right) / 2 for left, right in zip(vals, vals[1:])]

snr_val   = midpoints(snr_train)    # 9 values
gamma_val = midpoints(gamma_train)  # 4 values -> [0.75, 1.5, 2.75, 4.25]
C_val     = midpoints(C_train)      # 3 values -> [7.5, 12.5, 17.5]
per_cell_v = 5

# Optional sanity check: no validation grid point may coincide with a training grid point
assert set(snr_val).isdisjoint(snr_train)
assert set(gamma_val).isdisjoint(gamma_train)
assert set(C_val).isdisjoint(C_train)


# Test grid: 5 SNR x 4 gamma x 1 C values with per_cell_te = 1 graph per cell, i.e. 20 graphs.
# NOTE(review): the original comment said the test set has 200 graphs, which these
# values do not produce - confirm the intended grid or per_cell_te.
snr_test = (0.60, 1.00, 1.60, 2.20, 2.80)
gamma_test = (0.15, 0.60, 1.50, 3.00)
C_test = (10.0,)
per_cell_te = 1

def _normalize_theta_within_block(z: np.ndarray, theta: np.ndarray) -> np.ndarray:
    """
    让每个社区内 sum(theta_i) = 该社区规模（即社区内 theta 的平均值 = 1）。
    这样 θ 仅塑造“度异质性形状”，不改变平均度标定，便于用 ρ 精确控平均度。
    """
    z = np.asarray(z)
    theta = np.asarray(theta, dtype=float)
    for r in np.unique(z):
        idx = (z == r)
        s = theta[idx].sum()
        n_r = idx.sum()
        if s > 0:
            theta[idx] *= (n_r / s)
    return theta

class Generator(object):
    def __init__(self, N_train=50, N_test=100, N_val = 50,generative_model='SBM_multiclass', p_SBM=0.8, q_SBM=0.2, n_classes=2, path_dataset='dataset',
                 num_examples_train=100, num_examples_test=10, num_examples_val=10):
        """
        Store the generator configuration; nothing is sampled or read from disk here.

        N_train / N_test / N_val are the node counts used for each split,
        generative_model selects the sampler ('SBM' or 'SBM_multiclass'),
        p_SBM / q_SBM are the intra- / inter-community edge probabilities of the
        plain SBM samplers, and path_dataset is the root folder for stored data.
        The num_examples_* values are only used to build directory names in
        prepare_data.
        """
        # Node counts per split.
        self.N_train, self.N_test, self.N_val = N_train, N_test, N_val

        # Sampler selection and its parameters.
        self.generative_model = generative_model
        self.p_SBM, self.q_SBM = p_SBM, q_SBM
        self.n_classes = n_classes
        self.path_dataset = path_dataset

        # Dataset handles; assigned later by prepare_data / create_dataset_grid_dcsbm.
        self.data_train = self.data_test = self.data_val = None

        self.num_examples_train = num_examples_train
        self.num_examples_test = num_examples_test
        self.num_examples_val = num_examples_val

        # Preset two-community size splits; each pair sums to 1000 nodes.
        self.fixed_class_sizes = [
            (smaller, 1000 - smaller) for smaller in (500, 400, 300, 200, 100, 50)
        ]

    def SBM(self, p, q, N):
        W = np.zeros((N, N))

        p_prime = 1 - np.sqrt(1 - p)
        q_prime = 1 - np.sqrt(1 - q)

        n = N // 2

        W[:n, :n] = np.random.binomial(1, p, (n, n))
        W[n:, n:] = np.random.binomial(1, p, (N-n, N-n))
        W[:n, n:] = np.random.binomial(1, q, (n, N-n))
        W[n:, :n] = np.random.binomial(1, q, (N-n, n))
        W = W * (np.ones(N) - np.eye(N))
        W = np.maximum(W, W.transpose())

        perm = torch.randperm(N).numpy()
        blockA = perm < n
        labels = blockA * 2 - 1

        W_permed = W[perm]
        W_permed = W_permed[:, perm]
        return W_permed, labels


    def SBM_multiclass(self, p, q, N, n_classes):

        p_prime = 1 - np.sqrt(1 - p)
        q_prime = 1 - np.sqrt(1 - q)

        prob_mat = np.ones((N, N)) * q_prime

        n = N // n_classes  # 基础类别大小
        remainder = N % n_classes  # 不能整除的剩余部分
        n_last = n + remainder  # 最后一类的大小

        # 先对整除部分进行块状分配
        for i in range(n_classes - 1):  # 处理前 n_classes-1 类
            prob_mat[i * n: (i + 1) * n, i * n: (i + 1) * n] = p_prime

        # 处理最后一类
        start_idx = (n_classes - 1) * n  # 最后一类的起始索引
        prob_mat[start_idx: start_idx + n_last, start_idx: start_idx + n_last] = p_prime

        # 生成邻接矩阵
        W = np.random.rand(N, N) < prob_mat
        W = W.astype(int)

        W = W * (np.ones(N) - np.eye(N))  # 移除自环
        W = np.maximum(W, W.transpose())  # 确保无向图

        # 随机打乱节点顺序
        perm = torch.randperm(N).numpy()

        # 生成类别标签
        labels =np.minimum((perm // n) , n_classes - 1)

        W_permed = W[perm]
        W_permed = W_permed[:, perm]

        #计算P矩阵的特征向量
        prob_mat_permed = prob_mat[perm][:, perm]
        # np.fill_diagonal(prob_mat_permed, 0)  # 去除自环

        eigvals, eigvecs = np.linalg.eigh(prob_mat_permed)
        idx = np.argsort(eigvals)[::-1]
        eigvecs_top = eigvecs[:, idx[:n_classes]]

        return W_permed, labels, eigvecs_top  # 返回前n_classes特征向量

    from scipy.sparse import csr_matrix

    def imbalanced_DCSBM_multiclass(self, B_prob, labels, theta, rho, *, rng=None, return_eigvecs=False, topk=8):
        """
        给定 (B_prob, labels, theta, rho) 采样一张 DCSBM 图。
        返回: W_dense(or sparse), labels, (可选) eigvecs_top
        """
        import numpy as np
        rng = np.random.default_rng() if rng is None else rng

        N = len(labels)
        k = B_prob.shape[0]
        labels = np.asarray(labels, dtype=int)
        theta = np.asarray(theta, dtype=float)

        # 为了直观与简洁，这里先走 dense 概率矩阵再伯努利采样（与你原逻辑一致）；
        # 若 N 很大，可改成按社区块做外积并分块比较随机数以节省内存。
        # 构造 s_i = sum_s B_prob[g_i,s] * m_s
        m = np.array([theta[labels == r].sum() for r in range(k)], dtype=float)
        Bg = B_prob[labels]  # (N, k) 挑每个节点所在行
        s = (Bg @ m)  # (N,)

        # P_ij = rho * theta_i * theta_j * B_prob[g_i, g_j]
        # 可由：P = rho * diag(theta) * Z*B_prob*Z^T * diag(theta)
        # 用广播构造 dense 概率矩阵（简洁版）
        # 先构一个 (N, N) 的块 B_gij：通过行列映射索引
        B_rows = B_prob[labels]  # (N, k)
        B_gij = B_rows[:, labels]  # (N, N) —— 等价 Z B Z^T 的取值

        P = rho * (theta[:, None] * theta[None, :]) * B_gij
        np.fill_diagonal(P, 0.0)
        # 数值安全：clip 到 [0,1]
        P = np.clip(P, 0.0, 1.0)

        # 采样
        A = (rng.random((N, N)) < P).astype(np.float32)
        # 无向图对称化
        A = np.triu(A, 1)
        A = A + A.T

        eigvecs_top = None
        if return_eigvecs:
            # 简单返回前 topk 个特征向量（可替换为 Bethe Hessian / 归一化拉普拉斯）
            w, v = np.linalg.eigh(A)
            idx = np.argsort(w)[::-1][:topk]
            eigvecs_top = v[:, idx].astype(np.float32, copy=False)

        return A, labels, eigvecs_top

    def prepare_data(self):
        """
        Make sure the train / val / test DCSBM datasets exist on disk and store
        the sorted lists of their ``.npz`` file paths in ``self.data_train``,
        ``self.data_val`` and ``self.data_test``.

        A split is generated with create_dataset_grid_dcsbm only when its
        directory holds no ``.npz`` file yet; otherwise the existing files are
        reused as they are. The C grids are scaled by n_classes / 2.

        NOTE(review): min_size=50 is forwarded for every split, and
        _sample_class_sizes_dirichlet asserts N >= min_size * n_classes, so
        small N (e.g. the constructor default N_train=50 with 2 classes)
        cannot be generated - confirm the intended sizes.
        """
        def list_npz(folder):
            # Sorted so that the file order is stable across runs.
            return sorted([f for f in os.listdir(folder) if f.endswith(".npz")])

        def get_npz_dataset(path, mode, *, snr_grid, gamma_grid, C_grid, per_cell, min_size=50, base_seed=0):
            if not os.path.exists(path):
                os.makedirs(path)
                print(f"[创建数据集] {mode} 数据目录不存在，已新建：{path}")

            npz_files = list_npz(path)
            if npz_files:
                print(f"[读取数据] {mode} 集已存在，共 {len(npz_files)} 张图：{path}")
            else:
                print(f"[创建数据集] {mode} 数据未找到，开始生成...")
                self.create_dataset_grid_dcsbm(
                    path, mode=mode,
                    snr_grid=snr_grid,
                    gamma_grid=gamma_grid,
                    C_grid=C_grid,
                    per_cell=per_cell,
                    min_size=min_size,
                    base_seed=base_seed
                )
                npz_files = list_npz(path)
            return [os.path.join(path, f) for f in npz_files]

        # ---- dataset directories ----
        prefix = f"{self.generative_model}_nc{self.n_classes}_rand_"
        train_path = os.path.join(
            self.path_dataset, f"{prefix}gstr{self.N_train}_numtr{self.num_examples_train}")
        test_path = os.path.join(
            self.path_dataset, f"{prefix}gste{self.N_test}_numte{self.num_examples_test}")
        val_path = os.path.join(
            self.path_dataset, f"{prefix}val{self.N_val}_numval{self.num_examples_val}")

        # ---- module-level grids, with C scaled by the number of classes ----
        c_scale = self.n_classes / 2

        self.data_train = get_npz_dataset(
            train_path, 'train',
            snr_grid=snr_train, gamma_grid=gamma_train,
            C_grid=[x * c_scale for x in C_train], per_cell=per_cell_tr,
            min_size=50, base_seed=123
        )
        self.data_val = get_npz_dataset(
            val_path, 'val',
            snr_grid=snr_val, gamma_grid=gamma_val,
            C_grid=[x * c_scale for x in C_val], per_cell=per_cell_v,
            min_size=50, base_seed=2025
        )
        self.data_test = get_npz_dataset(
            test_path, 'test',
            snr_grid=snr_test, gamma_grid=gamma_test,
            C_grid=[x * c_scale for x in C_test], per_cell=per_cell_te,
            min_size=50, base_seed=31415
        )


    def sample_single(self, i, is_training=True):
        if is_training:
            dataset = self.data_train
        else:
            dataset = self.data_test
        example = dataset[i]
        if (self.generative_model == 'SBM_multiclass'):
            W_np = example['W']
            labels = np.expand_dims(example['labels'], 0)
            labels_var = torch.from_numpy(labels)
            if is_training:
                labels_var.requires_grad = True
            return W_np, labels_var


    def sample_otf_single(self, is_training=True, cuda=True):
        """
        Sample one graph on the fly with the configured generative model.

        Args:
            is_training: use N_train nodes when True, N_test otherwise.
            cuda: accepted for interface compatibility; not used here.

        Returns:
            (W, labels, eigvecs_top): W is a numpy array of shape (1, N, N),
            labels a torch tensor of shape (1, N), and eigvecs_top the
            eigenvectors returned by SBM_multiclass, or None for the plain
            'SBM' model, which does not compute any.

        Raises:
            ValueError: if the generative model is not supported.
        """
        N = self.N_train if is_training else self.N_test

        if self.generative_model == 'SBM':
            W, labels = self.SBM(self.p_SBM, self.q_SBM, N)
            # SBM() yields no eigenvectors; bind the name explicitly so the
            # return below does not raise UnboundLocalError.
            eigvecs_top = None
        elif self.generative_model == 'SBM_multiclass':
            W, labels, eigvecs_top = self.SBM_multiclass(self.p_SBM, self.q_SBM, N, self.n_classes)
        else:
            raise ValueError('Generative model {} not supported'.format(self.generative_model))

        # Add the leading batch axis; W deliberately stays a numpy array.
        labels = torch.from_numpy(np.expand_dims(labels, 0))
        W = np.expand_dims(W, 0)

        return W, labels, eigvecs_top

    def imbalanced_sample_otf_single(self, class_sizes , is_training=True, cuda=True):
        """
        Sample one graph on the fly, using prescribed community sizes for the
        multiclass model.

        Args:
            class_sizes: community sizes forwarded to
                imbalanced_SBM_multiclass; ignored by the plain 'SBM' model.
            is_training: use N_train nodes when True, N_test otherwise.
            cuda: accepted for interface compatibility; not used here.

        Returns:
            (W, labels, eigvecs_top): W is a numpy array of shape (1, N, N),
            labels a torch tensor of shape (1, N); eigvecs_top is None for the
            plain 'SBM' model.

        Raises:
            ValueError: if the generative model is not supported.

        NOTE(review): imbalanced_SBM_multiclass is not defined in the visible
        part of this class - confirm it is provided elsewhere before relying
        on the 'SBM_multiclass' branch.
        """
        N = self.N_train if is_training else self.N_test

        if self.generative_model == 'SBM':
            W, labels = self.SBM(self.p_SBM, self.q_SBM, N)
            # SBM() yields no eigenvectors; bind the name explicitly so the
            # return below does not raise UnboundLocalError.
            eigvecs_top = None
        elif self.generative_model == 'SBM_multiclass':
            W, labels, eigvecs_top = self.imbalanced_SBM_multiclass(self.p_SBM, self.q_SBM, N, self.n_classes, class_sizes)
        else:
            raise ValueError('Generative model {} not supported'.format(self.generative_model))

        # Add the leading batch axis; W deliberately stays a numpy array.
        labels = torch.from_numpy(np.expand_dims(labels, 0))
        W = np.expand_dims(W, 0)

        return W, labels, eigvecs_top


    def random_sample_otf_single(self, C = 10 ,is_training=True, cuda=True):
        """
        Sample one graph on the fly with randomly drawn SBM parameters.

        For 'SBM_multiclass' the ratio a / b is drawn between the ratios that
        find_a_given_snr returns for SNR 0.1 and SNR 1 at the given C, the
        community sizes are drawn at random (at least 20 nodes each), and the
        graph is sampled with imbalanced_SBM_multiclass.

        Args:
            C: degree constant forwarded to find_a_given_snr and the
                parameter sampler.
            is_training: use N_train nodes when True, N_test otherwise.
            cuda: accepted for interface compatibility; not used here.

        Returns:
            (W, labels, eigvecs_top, snr, class_sizes): W is a numpy array of
            shape (1, N, N) and labels a torch tensor of shape (1, N). For the
            plain 'SBM' model the last three entries are None, because that
            sampler draws none of them.

        Raises:
            ValueError: if the generative model is not supported.

        NOTE(review): imbalanced_SBM_multiclass is not defined in the visible
        part of this class - confirm it is provided elsewhere.
        """
        N = self.N_train if is_training else self.N_test

        if self.generative_model == 'SBM':
            W, labels = self.SBM(self.p_SBM, self.q_SBM, N)
            # Bind the values this branch does not produce so that the return
            # below does not raise UnboundLocalError.
            eigvecs_top = snr = class_sizes = None

        elif self.generative_model == 'SBM_multiclass':
            a_low, b_low = find_a_given_snr(0.1, self.n_classes, C)
            a_high, b_high = find_a_given_snr(1, self.n_classes, C)

            # Range of a / b ratios, ordered so that lower <= upper.
            lower_bound = a_low / b_low
            upper_bound = a_high / b_high
            if lower_bound > upper_bound:
                lower_bound, upper_bound = upper_bound, lower_bound

            p, q, class_sizes, snr = self.random_imbalanced_SBM_generator_balanced_sampling(
                N=N,
                n_classes=self.n_classes,
                C=C,
                alpha_range=(lower_bound, upper_bound),
                min_size= 20
            )
            W, labels, eigvecs_top = self.imbalanced_SBM_multiclass(p, q, N, self.n_classes, class_sizes)

        else:
            raise ValueError('Generative model {} not supported'.format(self.generative_model))

        # Add the leading batch axis; W deliberately stays a numpy array.
        labels = torch.from_numpy(np.expand_dims(labels, 0))
        W = np.expand_dims(W, 0)

        return W, labels, eigvecs_top, snr, class_sizes


    def random_imbalanced_SBM_generator_balanced_sampling(self, N, n_classes, C, *,
                                        alpha_range=(1.3, 2.8),
                                        min_size=5):
        """
        随机生成 SBM 模型的参数，社区大小为随机比例但总和为 N。
        返回 p, q, class_sizes, a, b, snr。
        """
        assert N >= min_size * n_classes

        # Step 1: 随机生成 a > b，使得 a + (k - 1) * b = C
        alpha = np.random.uniform(*alpha_range)
        b = C / (alpha + (n_classes - 1))
        a = alpha * b

        # Step 2: 计算边连接概率
        logn = np.log(N)
        p = a * logn / N
        q = b * logn / N

        # ✅ Step 3: 使用 Dirichlet 生成 class_sizes
        remaining = N - min_size * n_classes
        probs = np.random.dirichlet(np.ones(n_classes))  # 总和为1的概率向量
        extras = np.random.multinomial(remaining, probs)
        class_sizes = [min_size + e for e in extras]

        # Step 4: 计算 SNR
        snr = (a - b) ** 2 / (n_classes * (a + (n_classes - 1) * b))

        return p, q, class_sizes, snr

    def _sample_class_sizes_dirichlet(
            self,
            N: int,
            n_classes: int,
            gamma: float,
            min_size: int,
            rng: Union[int, np.random.Generator],
            gamma_jitter: float = 0.5,
            return_labels: bool = False,  # 新增：是否返回逐点 labels
            shuffle_labels: bool = True,  # 新增：是否打乱节点顺序
            eps: float = 1e-12  # 数值下界，避免 gamma_used 过小
    ) -> Union[list, Tuple[list, np.ndarray, Dict]]:
        """
        按 Dirichlet(gamma) 采样类比例 + 每类至少 min_size。
        - 默认返回 sizes（与旧版兼容）；
        - 若 return_labels=True，同时返回 per-node labels (N,) 和 meta。
        """
        # --- RNG 统一化 ---
        if isinstance(rng, (int, np.integer)):
            rng = np.random.default_rng(int(rng))

        assert N >= min_size * n_classes, "N 必须 >= min_size * n_classes"
        remaining = N - min_size * n_classes

        # --- gamma 抖动并做下界裁剪 ---
        if gamma_jitter and gamma_jitter > 0:
            mult = rng.uniform(max(0.0, 1.0 - gamma_jitter), 1.0 + gamma_jitter)
            gamma_used = max(eps, float(gamma) * float(mult))
        else:
            gamma_used = max(eps, float(gamma))

        alpha = np.full(n_classes, gamma_used, dtype=float)

        # --- 采样类别比例（若 remaining=0，给个均匀兜底） ---
        probs = rng.dirichlet(alpha) if remaining > 0 else np.full(n_classes, 1.0 / n_classes)

        # --- 分配剩余名额 ---
        extras = rng.multinomial(remaining, probs) if remaining > 0 else np.zeros(n_classes, dtype=int)
        sizes = (min_size + extras).astype(int).tolist()

        if not return_labels:
            return sizes  # 与旧代码保持一致

        # --- 展开成逐点标签 (N,) ---
        labels = np.concatenate([np.full(sz, c, dtype=int) for c, sz in enumerate(sizes)])
        if shuffle_labels:
            labels = labels[rng.permutation(N)]

        meta = dict(gamma_used=gamma_used, probs=probs, sizes=np.array(sizes, dtype=int))
        return sizes, labels, meta


    def gen_one_dcsbm_by_targets(
            self, N, n_classes, C, target_snr, gamma, min_size=5, *, rng=None,
            heterophily=False, hetero_prob=None,
            # Jitter parameters
            r_jitter=0.05, keep_assortativity=True,
            pq_jitter=None, C_jitter=0.1, C_jitter_mode='relative',
            b_floor=1e-6,
            # Distribution used to draw the degree parameters theta
            theta_dist='lognormal', theta_kwargs=None,
            # Whether to normalise theta inside every community
            normalize_theta=True
    ):
        """
        Build the parameters of one DCSBM graph from a target SNR, a Dirichlet
        concentration ``gamma`` for the community sizes, and a degree constant C.

        The adjacency matrix itself is not sampled here; pass the result to
        imbalanced_DCSBM_multiclass for that.

        Args:
            N: number of nodes.
            n_classes: number of communities (k).
            C: target value of a + (k - 1) * b before jitter.
            target_snr: SNR handed to find_a_given_snr.
            gamma: Dirichlet concentration for the community sizes.
            min_size: minimum number of nodes per community.
            rng: numpy Generator; a fresh one is created when None.
            heterophily: use the heterophilous variant (a and b swapped);
                only consulted when hetero_prob is None.
            hetero_prob: if given, heterophily is drawn with this probability.
            r_jitter: relative jitter applied to the ratio r = a / b.
            keep_assortativity: re-order (a, b) after jitter so that a > b
                (homophily) or a < b (heterophily) still holds.
            pq_jitter: optional pair (pj, qj) of relative jitters for the
                diagonal / off-diagonal entries of B_prob.
            C_jitter: size of the jitter applied to C.
            C_jitter_mode: 'relative' (multiplicative) or 'absolute' (additive).
            b_floor: lower bound enforced on b.
            theta_dist: 'lognormal' or 'gamma'.
            theta_kwargs: optional distribution parameters ('mu' / 'sigma' for
                lognormal, 'shape' / 'scale' for gamma).
            normalize_theta: rescale theta to mean 1 inside every community.

        Returns:
            Tuple ``(B_prob, labels, theta, rho, snr_real, a, b, gamma,
            is_hetero, C_used)``. B_prob is the (k, k) block matrix already
            multiplied by log(N) / N, snr_real is
            (a - b)^2 / (k * (a + (k - 1) * b)) for the final a and b, and
            gamma is the input value passed through unchanged.

        Raises:
            ValueError: if theta_dist is not supported.
        """
        import numpy as np
        rng = np.random.default_rng() if rng is None else rng

        # === 0) Optional small jitter of C ===
        C_used = float(C)
        if C_jitter and C_jitter > 0:
            if C_jitter_mode == 'relative':
                C_used = C_used * rng.uniform(1.0 - float(C_jitter), 1.0 + float(C_jitter))
            elif C_jitter_mode == 'absolute':
                C_used = C_used + rng.uniform(-float(C_jitter), float(C_jitter))
            # Keep C strictly positive (any other mode leaves C un-jittered).
            C_used = max(C_used, 1e-6)

        # === 1) Solve for (a, b) from (target_snr, k, C_used) ===
        # find_a_given_snr comes from controlsnr; presumably it returns (a, b)
        # with a + (k - 1) * b = C_used - TODO confirm against that module.
        a0, b0 = find_a_given_snr(target_snr, n_classes, C_used)
        r0 = a0 / b0 if b0 > 0 else 1.0

        # === 2) Homophily vs. heterophily ===
        if hetero_prob is not None:
            is_hetero = bool(rng.random() < float(hetero_prob))
        else:
            is_hetero = bool(heterophily)
        if is_hetero:
            # Heterophily is modelled by swapping the roles of a and b.
            a0, b0 = b0, a0
            r0 = a0 / b0 if b0 > 0 else 1.0

        # === 3) Jitter r and back-substitute so that a + (k - 1) * b = C_used ===
        r = r0
        if r_jitter and r_jitter > 0:
            r *= rng.uniform(1.0 - float(r_jitter), 1.0 + float(r_jitter))
        b = C_used / (r + n_classes - 1.0)
        a = r * b

        # === 4) Preserve the homophily / heterophily ordering ===
        # NOTE(review): when this swap triggers and k != 2, a + (k - 1) * b no
        # longer equals C_used - confirm this is acceptable.
        if keep_assortativity:
            if is_hetero:
                if not (a < b):  # heterophily requires a < b
                    a, b = min(a, b), max(a, b)
            else:
                if not (a > b):  # homophily requires a > b
                    a, b = max(a, b), min(a, b)

        # === 4.5) Safety floor so b never gets too small ===
        if b < b_floor:
            b = b_floor
            a = C_used - (n_classes - 1) * b

        # === 5) Community sizes and shuffled per-node labels ===
        # Only `labels` is used below; `sizes` and `meta` are discarded.
        sizes, labels, meta = self._sample_class_sizes_dirichlet(
            N=N, n_classes=n_classes, gamma=gamma, min_size=min_size, rng=rng,
            return_labels=True, shuffle_labels=True
        )

        # === 6) Draw theta ===
        if theta_kwargs is None:
            theta_kwargs = {}
        if theta_dist == 'gamma':
            # Defaults: Gamma(shape=2, scale=0.5), whose mean is shape * scale = 1.
            shape = float(theta_kwargs.get('shape', 2.0))
            scale = float(theta_kwargs.get('scale', 0.5))
            theta = rng.gamma(shape=shape, scale=scale, size=N)
        elif theta_dist == 'lognormal':
            # Defaults give mean exp(mu + sigma^2 / 2) = exp(-0.32 + 0.32) = 1.
            mu = float(theta_kwargs.get('mu', -0.32))
            sigma = float(theta_kwargs.get('sigma', 0.8))
            theta = rng.lognormal(mean=mu, sigma=sigma, size=N)
        else:
            raise ValueError(f"Unsupported theta_dist: {theta_dist}")

        # Within-community normalisation (recommended): sum_{i in r} theta_i = n_r
        if normalize_theta:
            theta = _normalize_theta_within_block(labels, theta)

        # === 7) Build B_prob on the probability scale (includes log(N) / N) ===
        logn = np.log(N)
        scale = logn / N
        B_prob = np.full((n_classes, n_classes), b * scale, dtype=float)
        np.fill_diagonal(B_prob, a * scale)

        # Optional extra jitter of the diagonal (p) and off-diagonal (q) entries;
        # a single factor is drawn for each group. The returned SNR is computed
        # from a and b and does not reflect this jitter.
        if pq_jitter is not None:
            pj, qj = pq_jitter
            if pj and pj > 0:
                B_prob[np.eye(n_classes, dtype=bool)] *= rng.uniform(1.0 - float(pj), 1.0 + float(pj))
            if qj and qj > 0:
                off = ~np.eye(n_classes, dtype=bool)
                B_prob[off] *= rng.uniform(1.0 - float(qj), 1.0 + float(qj))

        # === 8) rho that pulls the expected average degree to the target ===
        # Target average degree: d_target = C_used * log(N)
        d_target = C_used * logn

        # m_r = sum of theta over the nodes of community r
        m = np.array([theta[labels == r].sum() for r in range(n_classes)], dtype=float)
        # Model average degree without rho: d_model = (1 / N) * m^T B_prob m
        d_model = (m @ B_prob @ m) / N
        rho = d_target / max(d_model, 1e-12)

        # rho * theta_i * theta_j * B_prob may exceed 1 for individual pairs; no
        # bound is enforced here, the sampler clips probabilities to [0, 1].
        return B_prob, labels, theta, rho, (a - b) ** 2 / (
                    n_classes * (a + (n_classes - 1) * b)), a, b, gamma, is_hetero, C_used

    def create_dataset_grid_dcsbm(self, directory, mode='train', *,
                                  snr_grid=(0.6, 0.9, 1.1, 1.3, 1.6, 2.0, 2.5, 3.0),
                                  gamma_grid=(0.15, 0.3, 0.6, 1.0, 2.0),
                                  C_grid=(10.0,),
                                  per_cell=20,
                                  min_size=5,
                                  base_seed=0,
                                  # DCSBM 额外可控
                                  theta_dist='lognormal', theta_kwargs=None,
                                  normalize_theta=True,
                                  return_eigvecs=True, topk=8):
        """
        在 (SNR × gamma × C) 网格上生成 DCSBM 数据；每格 per_cell 张图。
        """
        import os
        from scipy.sparse import csr_matrix
        os.makedirs(directory, exist_ok=True)

        if mode == 'train':
            N = self.N_train
            num_graphs_expected = len(snr_grid) * len(gamma_grid) * len(C_grid) * per_cell
            self.data_train = directory
        elif mode == 'val':
            N = self.N_val
            num_graphs_expected = len(snr_grid) * len(gamma_grid) * len(C_grid) * per_cell
            self.data_val = directory
        elif mode == 'test':
            N = self.N_test
            num_graphs_expected = len(snr_grid) * len(gamma_grid) * len(C_grid) * per_cell
            self.data_test = directory
        else:
            raise ValueError(f"Unsupported mode: {mode}")

        idx = 0
        for c_idx, C in enumerate(C_grid):
            for s_idx, snr_target in enumerate(snr_grid):
                for g_idx, gamma in enumerate(gamma_grid):
                    cell_seed = base_seed + (c_idx * 10_000_000 + s_idx * 10_000 + g_idx * 100)
                    rng = np.random.default_rng(cell_seed)

                    for rep in range(per_cell):
                        # === 核心改动：用 DCSBM 的参数生成器 ===
                        (B_prob, labels, theta, rho,
                         snr_real, a, b, gamma_val, is_hetero, C_used) = self.gen_one_dcsbm_by_targets(
                            N=N, n_classes=self.n_classes, C=C,
                            target_snr=snr_target, gamma=gamma,
                            min_size=min_size, rng=rng,
                            # 抖动参数承接原来
                            r_jitter=0.05, keep_assortativity=True,
                            pq_jitter=None, C_jitter=0.1, C_jitter_mode='relative',
                            b_floor=1e-6,
                            # theta 控制
                            theta_dist=theta_dist, theta_kwargs=theta_kwargs,
                            normalize_theta=normalize_theta
                        )

                        # === 采样 DCSBM 图 ===
                        W_dense, labels_out, eigvecs_top = self.imbalanced_DCSBM_multiclass(
                            B_prob=B_prob, labels=labels, theta=theta, rho=rho,
                            rng=rng, return_eigvecs=return_eigvecs, topk=topk
                        )
                        W_sparse = csr_matrix(W_dense)

                        fname = (f"{mode}_i{idx:05d}"
                                 f"__C{C:.2f}__snr{snr_target:.3f}"
                                 f"__g{gamma:.3f}__rep{rep:02d}.npz")
                        path = os.path.join(directory, fname)

                        # === 存盘：保留 DCSBM 关键元数据（θ、ρ、B_prob 等）===
                        np.savez_compressed(
                            path,
                            adj_data=W_sparse.data,
                            adj_indices=W_sparse.indices,
                            adj_indptr=W_sparse.indptr,
                            adj_shape=W_sparse.shape,
                            labels=labels_out.astype(np.int32),
                            # 记录 a,b,C, snr
                            a=a, b=b, C=C, C_used=C_used,
                            snr_target=snr_target, snr_real=snr_real,
                            gamma=gamma_val,
                            # DCSBM 关键
                            theta=theta.astype(np.float32),
                            rho=float(rho),
                            B_prob=B_prob.astype(np.float32),
                            # （可选）特征向量
                            eigvecs_top=eigvecs_top if return_eigvecs else None
                        )
                        idx += 1

        print(f"[{mode}] (DCSBM) 网格数据完成: 共 {idx} 张（期望 {num_graphs_expected}）。目录: {directory}")

    def copy(self):
        return copy.deepcopy(self)